{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pymysql\n",
    "import os, sys\n",
    "from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, AdaBoostRegressor\n",
    "from sklearn.model_selection import train_test_split\n",
    "from matplotlib import pyplot as plt\n",
    "import warnings\n",
    "\n",
    "warnings.filterwarnings('ignore')\n",
    "basePath = \"..\\\\dataset\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>REPORT_DATE</th>\n",
       "      <th>BANNER_NAME</th>\n",
       "      <th>PRODUCT_BKEY</th>\n",
       "      <th>QLI</th>\n",
       "      <th>LAST_1_MONTH_AVG</th>\n",
       "      <th>LAST_1_MONTH_MIN</th>\n",
       "      <th>LAST_1_MONTH_MAX</th>\n",
       "      <th>LAST_1_MONTH_SUM</th>\n",
       "      <th>LAST_1_MONTH_MID</th>\n",
       "      <th>LAST_1_MONTH_MID25</th>\n",
       "      <th>...</th>\n",
       "      <th>LAST_6_MONTH_MID</th>\n",
       "      <th>LAST_6_MONTH_MID25</th>\n",
       "      <th>LAST_6_MONTH_MID75</th>\n",
       "      <th>LAST_12_MONTH_AVG</th>\n",
       "      <th>LAST_12_MONTH_MIN</th>\n",
       "      <th>LAST_12_MONTH_MAX</th>\n",
       "      <th>LAST_12_MONTH_SUM</th>\n",
       "      <th>LAST_12_MONTH_MID</th>\n",
       "      <th>LAST_12_MONTH_MID25</th>\n",
       "      <th>LAST_12_MONTH_MID75</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2016-09-01</td>\n",
       "      <td>Carrefour</td>\n",
       "      <td>XCN0401200</td>\n",
       "      <td>2.12976</td>\n",
       "      <td>2.12976</td>\n",
       "      <td>2.12976</td>\n",
       "      <td>2.12976</td>\n",
       "      <td>2.12976</td>\n",
       "      <td>2.12976</td>\n",
       "      <td>2.12976</td>\n",
       "      <td>...</td>\n",
       "      <td>2.12976</td>\n",
       "      <td>2.12976</td>\n",
       "      <td>2.12976</td>\n",
       "      <td>2.12976</td>\n",
       "      <td>2.12976</td>\n",
       "      <td>2.12976</td>\n",
       "      <td>2.12976</td>\n",
       "      <td>2.12976</td>\n",
       "      <td>2.12976</td>\n",
       "      <td>2.12976</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2016-09-01</td>\n",
       "      <td>Carrefour</td>\n",
       "      <td>XCN0405400</td>\n",
       "      <td>0.27540</td>\n",
       "      <td>0.27540</td>\n",
       "      <td>0.27540</td>\n",
       "      <td>0.27540</td>\n",
       "      <td>0.27540</td>\n",
       "      <td>0.27540</td>\n",
       "      <td>0.27540</td>\n",
       "      <td>...</td>\n",
       "      <td>0.27540</td>\n",
       "      <td>0.27540</td>\n",
       "      <td>0.27540</td>\n",
       "      <td>0.27540</td>\n",
       "      <td>0.27540</td>\n",
       "      <td>0.27540</td>\n",
       "      <td>0.27540</td>\n",
       "      <td>0.27540</td>\n",
       "      <td>0.27540</td>\n",
       "      <td>0.27540</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2016-09-01</td>\n",
       "      <td>Carrefour</td>\n",
       "      <td>XCN0413900</td>\n",
       "      <td>4.19832</td>\n",
       "      <td>4.19832</td>\n",
       "      <td>4.19832</td>\n",
       "      <td>4.19832</td>\n",
       "      <td>4.19832</td>\n",
       "      <td>4.19832</td>\n",
       "      <td>4.19832</td>\n",
       "      <td>...</td>\n",
       "      <td>4.19832</td>\n",
       "      <td>4.19832</td>\n",
       "      <td>4.19832</td>\n",
       "      <td>4.19832</td>\n",
       "      <td>4.19832</td>\n",
       "      <td>4.19832</td>\n",
       "      <td>4.19832</td>\n",
       "      <td>4.19832</td>\n",
       "      <td>4.19832</td>\n",
       "      <td>4.19832</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2016-09-01</td>\n",
       "      <td>Carrefour</td>\n",
       "      <td>XCN0417509</td>\n",
       "      <td>1.89414</td>\n",
       "      <td>1.89414</td>\n",
       "      <td>1.89414</td>\n",
       "      <td>1.89414</td>\n",
       "      <td>1.89414</td>\n",
       "      <td>1.89414</td>\n",
       "      <td>1.89414</td>\n",
       "      <td>...</td>\n",
       "      <td>1.89414</td>\n",
       "      <td>1.89414</td>\n",
       "      <td>1.89414</td>\n",
       "      <td>1.89414</td>\n",
       "      <td>1.89414</td>\n",
       "      <td>1.89414</td>\n",
       "      <td>1.89414</td>\n",
       "      <td>1.89414</td>\n",
       "      <td>1.89414</td>\n",
       "      <td>1.89414</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2016-09-01</td>\n",
       "      <td>Carrefour</td>\n",
       "      <td>XCN0420000</td>\n",
       "      <td>3.58020</td>\n",
       "      <td>3.58020</td>\n",
       "      <td>3.58020</td>\n",
       "      <td>3.58020</td>\n",
       "      <td>3.58020</td>\n",
       "      <td>3.58020</td>\n",
       "      <td>3.58020</td>\n",
       "      <td>...</td>\n",
       "      <td>3.58020</td>\n",
       "      <td>3.58020</td>\n",
       "      <td>3.58020</td>\n",
       "      <td>3.58020</td>\n",
       "      <td>3.58020</td>\n",
       "      <td>3.58020</td>\n",
       "      <td>3.58020</td>\n",
       "      <td>3.58020</td>\n",
       "      <td>3.58020</td>\n",
       "      <td>3.58020</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 53 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "  REPORT_DATE BANNER_NAME PRODUCT_BKEY      QLI  LAST_1_MONTH_AVG  \\\n",
       "0  2016-09-01   Carrefour   XCN0401200  2.12976           2.12976   \n",
       "1  2016-09-01   Carrefour   XCN0405400  0.27540           0.27540   \n",
       "2  2016-09-01   Carrefour   XCN0413900  4.19832           4.19832   \n",
       "3  2016-09-01   Carrefour   XCN0417509  1.89414           1.89414   \n",
       "4  2016-09-01   Carrefour   XCN0420000  3.58020           3.58020   \n",
       "\n",
       "   LAST_1_MONTH_MIN  LAST_1_MONTH_MAX  LAST_1_MONTH_SUM  LAST_1_MONTH_MID  \\\n",
       "0           2.12976           2.12976           2.12976           2.12976   \n",
       "1           0.27540           0.27540           0.27540           0.27540   \n",
       "2           4.19832           4.19832           4.19832           4.19832   \n",
       "3           1.89414           1.89414           1.89414           1.89414   \n",
       "4           3.58020           3.58020           3.58020           3.58020   \n",
       "\n",
       "   LAST_1_MONTH_MID25         ...           LAST_6_MONTH_MID  \\\n",
       "0             2.12976         ...                    2.12976   \n",
       "1             0.27540         ...                    0.27540   \n",
       "2             4.19832         ...                    4.19832   \n",
       "3             1.89414         ...                    1.89414   \n",
       "4             3.58020         ...                    3.58020   \n",
       "\n",
       "   LAST_6_MONTH_MID25  LAST_6_MONTH_MID75  LAST_12_MONTH_AVG  \\\n",
       "0             2.12976             2.12976            2.12976   \n",
       "1             0.27540             0.27540            0.27540   \n",
       "2             4.19832             4.19832            4.19832   \n",
       "3             1.89414             1.89414            1.89414   \n",
       "4             3.58020             3.58020            3.58020   \n",
       "\n",
       "   LAST_12_MONTH_MIN  LAST_12_MONTH_MAX  LAST_12_MONTH_SUM  LAST_12_MONTH_MID  \\\n",
       "0            2.12976            2.12976            2.12976            2.12976   \n",
       "1            0.27540            0.27540            0.27540            0.27540   \n",
       "2            4.19832            4.19832            4.19832            4.19832   \n",
       "3            1.89414            1.89414            1.89414            1.89414   \n",
       "4            3.58020            3.58020            3.58020            3.58020   \n",
       "\n",
       "   LAST_12_MONTH_MID25  LAST_12_MONTH_MID75  \n",
       "0              2.12976              2.12976  \n",
       "1              0.27540              0.27540  \n",
       "2              4.19832              4.19832  \n",
       "3              1.89414              1.89414  \n",
       "4              3.58020              3.58020  \n",
       "\n",
       "[5 rows x 53 columns]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv(basePath+ \"\\\\SalesTrend.csv\")\n",
    "df.head()\n",
    "#len(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "def getDayNo(arr):\n",
    "    count = len(arr)\n",
    "    for i in range(count):\n",
    "        dd = arr[i]\n",
    "        dd = datetime.datetime.strptime(dd,\"%Y-%m-%d\")\n",
    "\n",
    "        dayNo = dd.timetuple().tm_yday\n",
    "        arr[i] = dayNo\n",
    "        #print(\"date:\", dd, \"; dayNo: \", dayNo)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import mean_squared_error\n",
    "from sklearn.model_selection import train_test_split\n",
    "def plot_learning_curves(model, X, y):\n",
    "    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2)\n",
    "    train_errors, val_errors = [], []\n",
    "    for m in range(1, len(X_train)):\n",
    "        model.fit(X_train[:m], y_train[:m])\n",
    "        y_train_predict = model.predict(X_train[:m])\n",
    "        y_val_predict = model.predict(X_val)\n",
    "        train_errors.append(mean_squared_error(y_train_predict, y_train[:m]))\n",
    "        val_errors.append(mean_squared_error(y_val_predict, y_val))\n",
    "    plt.plot(np.sqrt(train_errors), \"r-+\", linewidth=2, label=\"train\")\n",
    "    plt.plot(np.sqrt(val_errors), \"b-\", linewidth=3, label=\"val\")\n",
    "\n",
    "#plot_learning_curves(rf, X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_train(model, X_train, y_train, X_test, y_test):\n",
    "    model.fit(X_train, y_train)\n",
    "    yHat = model.predict(X_test)\n",
    "    #plt.figure(figsize=(15, 12))\n",
    "    plt.plot(range(len(yHat)), yHat, \"r-\")\n",
    "    plt.plot(range(len(y_test)), y_test, \"b--\")\n",
    "    from sklearn.metrics import mean_squared_error\n",
    "    mse = mean_squared_error(y_test, yHat)\n",
    "    print(\"mse: \", mse, \"rmse: \", np.sqrt(mse))\n",
    "    #plot_learning_curves(rf, X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'os' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-1-fb9670762fce>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[1;31m# 获取banner聚合信息\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[0mfilePath\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mos\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbasePath\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m\"banner_group.csv\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      3\u001b[0m \u001b[0mdf_banner\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mread_csv\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfilePath\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      4\u001b[0m \u001b[1;31m#df_banner.head()\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mNameError\u001b[0m: name 'os' is not defined"
     ]
    }
   ],
   "source": [
    "# 获取banner聚合信息\n",
    "filePath = os.path.join(basePath, \"banner_group.csv\")\n",
    "df_banner = pd.read_csv(filePath)\n",
    "#df_banner.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 获取趋势聚合信息\n",
    "filePath = os.path.join(basePath, \"SalesTrend.csv\")\n",
    "df_trend = pd.read_csv(filePath)\n",
    "df_trend_carre = df_trend[df_trend.BANNER_NAME == 'Carrefour']\n",
    "df_trend_carre_agg = df_trend_carre.groupby('REPORT_DATE').mean().reset_index()\n",
    "# df_trend_carre_agg.head(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(df_banner), len(df_trend_carre_agg))\n",
    "df_merge = df_banner.merge(df_trend_carre_agg, on=[\"REPORT_DATE\"], how=\"inner\")\n",
    "df_merge.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "corr_matrix = df_merge.corr()\n",
    "for item in corr_matrix[\"QLI\"].sort_values(ascending=False):\n",
    "    print(item)\n",
    "corr_matrix[\"QLI\"].sort_values(ascending=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'df_merge' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-17-212ff003ae08>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mlabels\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdf_merge\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m\"QLI\"\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      2\u001b[0m \u001b[0mdf_merge_clean\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdf_merge\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcopy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      3\u001b[0m \u001b[0mgetDayNo\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdf_merge_clean\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m\"REPORT_DATE\"\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      4\u001b[0m \u001b[0mdf_merge_clean\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdf_merge_clean\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdrop\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"BANNER_NAME\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      5\u001b[0m \u001b[0mdf_merge_clean\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdf_merge_clean\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdrop\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"QLI\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mNameError\u001b[0m: name 'df_merge' is not defined"
     ]
    }
   ],
   "source": [
    "labels = df_merge[\"QLI\"].values\n",
    "df_merge_clean = df_merge.copy()\n",
    "getDayNo(df_merge_clean[\"REPORT_DATE\"].values)\n",
    "df_merge_clean = df_merge_clean.drop(\"BANNER_NAME\", 1)\n",
    "df_merge_clean = df_merge_clean.drop(\"QLI\", 1)\n",
    "df_merge_clean.head(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'df_merge' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-18-24878df581fc>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[1;31m# 尝试去掉相关度低的列，但是从rmse效果不是很明显\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m df_merge_clean = df_merge.filter(items=[\"WEEK_OF_YEAR\",\"LAST_1_MONTH_AVG\",\"LAST_1_MONTH_MID75\",\"LAST_1_MONTH_MAX\",\"LAST_1_MONTH_SUM\",\n\u001b[0m\u001b[0;32m      3\u001b[0m                                   \u001b[1;34m\"LAST_1_MONTH_MID\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;34m\"LAST_2_MONTH_AVG\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;34m\"LAST_1_MONTH_MID25\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;34m\"LAST_2_MONTH_MAX\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      4\u001b[0m                                   \u001b[1;34m\"LAST_2_MONTH_SUM\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;34m\"LAST_2_MONTH_MID75\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;34m\"MONTH_OF_YEAR\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;34m\"LAST_2_MONTH_MID\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      5\u001b[0m                                   \u001b[1;34m\"LAST_2_MONTH_MID25\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;34m\"LAST_3_MONTH_AVG\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;34m\"LAST_3_MONTH_MAX\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;34m\"LAST_1_MONTH_MIN\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mNameError\u001b[0m: name 'df_merge' is not defined"
     ]
    }
   ],
   "source": [
    "# 尝试去掉相关度低的列，但是从rmse效果不是很明显\n",
    "df_merge_clean = df_merge.filter(items=[\"WEEK_OF_YEAR\",\"LAST_1_MONTH_AVG\",\"LAST_1_MONTH_MID75\",\"LAST_1_MONTH_MAX\",\"LAST_1_MONTH_SUM\",\n",
    "                                  \"LAST_1_MONTH_MID\",\"LAST_2_MONTH_AVG\",\"LAST_1_MONTH_MID25\",\"LAST_2_MONTH_MAX\",\n",
    "                                  \"LAST_2_MONTH_SUM\",\"LAST_2_MONTH_MID75\",\"MONTH_OF_YEAR\",\"LAST_2_MONTH_MID\",\n",
    "                                  \"LAST_2_MONTH_MID25\",\"LAST_3_MONTH_AVG\",\"LAST_3_MONTH_MAX\",\"LAST_1_MONTH_MIN\",\n",
    "                                  \"LAST_3_MONTH_MID25\",\"LAST_3_MONTH_SUM\"])\n",
    "\n",
    "df_merge_clean.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'df_merge_clean' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-19-20b7a37df8aa>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdf_merge_clean\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      2\u001b[0m \u001b[0mX_train\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX_test\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my_test\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtrain_test_split\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdf_merge_clean\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtest_size\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0.2\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mrandom_state\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m42\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mNameError\u001b[0m: name 'df_merge_clean' is not defined"
     ]
    }
   ],
   "source": [
    "len(df_merge_clean.values)\n",
    "X_train, X_test, y_train, y_test = train_test_split(df_merge_clean.values, labels, test_size=0.2, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'X_train' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-20-61af1cec8a6d>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0msklearn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlinear_model\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mLinearRegression\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      2\u001b[0m \u001b[0mlr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mLinearRegression\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 3\u001b[1;33m \u001b[0mmodel_train\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlr\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX_train\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX_test\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my_test\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      4\u001b[0m \u001b[1;31m# 正常            mse:  1.740521567096297 rmse:  1.3192882805119952\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      5\u001b[0m \u001b[1;31m# 只是取高相关的列：mse:  1.6498981200989498 rmse:  1.28448360055664\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mNameError\u001b[0m: name 'X_train' is not defined"
     ]
    }
   ],
   "source": [
    "from sklearn.linear_model import LinearRegression\n",
    "lr = LinearRegression()\n",
    "model_train(lr, X_train, y_train, X_test, y_test)\n",
    "# 正常            mse:  1.740521567096297 rmse:  1.3192882805119952\n",
    "# 只是取高相关的列：mse:  1.6498981200989498 rmse:  1.28448360055664"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rf = RandomForestRegressor(random_state=42,min_samples_leaf=150)\n",
    "model_train(rf, X_train, y_train, X_test, y_test)\n",
    "# 正常            mse:  1.6457994000584257 rmse:  1.282887134575145\n",
    "# 只是取高相关的列：mse:  1.80228974242949 rmse:  1.3424938519149687\n",
    "# max_features=1 :mse:  1.3514654105490587 rmse:  1.1625254451189697(正常)\n",
    "# max_features=1 :mse:  1.6524300339530966 rmse:  1.2854687992919536(去掉无关列)\n",
    "# n_estimators=100:mse:  1.556241037092118 rmse:  1.2474939026272305\n",
    "# min_samples_leaf=150: mse:  1.4474719092476627 rmse:  1.2031092673766846(如果加max_feature=1反而会略小)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.svm import SVR\n",
    "svr =SVR()\n",
    "model_train(svr, X_train, y_train, X_test, y_test)\n",
    "# 正常            mse:  1.4925522793217523 rmse:  1.2217005686017144\n",
    "# 只是取高相关的列：mse:  1.4464301776578847 rmse:  1.2026762563790327"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gbr = GradientBoostingRegressor()\n",
    "model_train(gbr, X_train, y_train, X_test, y_test)\n",
    "# 正常            mse:  1.7530168784915257 rmse:  1.324015437406802\n",
    "# 只是取高相关的列：mse:  1.9172499682142525 rmse:  1.3846479582241302"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'X_train' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-21-68605e220029>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[0mabr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mAdaBoostRegressor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[0mmodel_train\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mabr\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX_train\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX_test\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my_test\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      3\u001b[0m \u001b[1;31m# 正常            mse:  1.7041799726923388 rmse:  1.3054424432706095\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      4\u001b[0m \u001b[1;31m# 只是取高相关的列：mse:  1.9216638377206163 rmse:  1.3862409017629715\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mNameError\u001b[0m: name 'X_train' is not defined"
     ]
    }
   ],
   "source": [
    "abr = AdaBoostRegressor()\n",
    "model_train(abr, X_train, y_train, X_test, y_test)\n",
    "# 正常            mse:  1.7041799726923388 rmse:  1.3054424432706095\n",
    "# 只是取高相关的列：mse:  1.9216638377206163 rmse:  1.3862409017629715"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
